{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9f0f7162-2e55-48c7-8f12-2cb7ce857677",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project modules come first so that the explicit imports below always win over any\n",
    "# same-named symbol pulled in by the star imports.\n",
    "from models import *\n",
    "from utils import *\n",
    "from hifi_data_process import *\n",
    "\n",
    "import gc\n",
    "import itertools\n",
    "import json\n",
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.cuda.amp import GradScaler, autocast\n",
    "from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "67de5423-8ef9-4b77-9510-247f880be4b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# DataLoader consumed by the inspection and training cells below.\n",
    "data_loader = get_DataLoader(\n",
    "    batch_size=16,\n",
    "    num_workers=4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "66388101-59ed-44c4-a1af-ee24d6a5c17c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull a single batch out of the loader for inspection.\n",
    "data_iter = iter(data_loader)\n",
    "batch = next(data_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b240c604-ecb0-4fb3-b2e4-b6fc6260300a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['audio_label', 'sr', 'mel_spec'])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show which fields a single sample of the batch carries.\n",
    "first_sample = batch[0]\n",
    "first_sample.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "531a43b0-a00e-4864-a6b3-b3514b764ad5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "device = torch.device(device_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7231d7f4-2b1f-44a0-ab57-b500111ee7c0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c4dcc882-8eb1-4cc9-afe1-f9b8c79208b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:30: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n",
      "  warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n"
     ]
    }
   ],
   "source": [
    "# Read the hyper-parameter file, then build the generator and both discriminators on `device`.\n",
    "config = 'config_v3.json'\n",
    "with open(config) as config_file:\n",
    "    data = config_file.read()\n",
    "json_config = json.loads(data)\n",
    "h = AttrDict(json_config)\n",
    "\n",
    "# Instantiate in the same order as before, then move every network to the target device.\n",
    "generator = Generator(h)\n",
    "mpd = MultiPeriodDiscriminator()\n",
    "msd = MultiScaleDiscriminator()\n",
    "generator, mpd, msd = (net.to(device) for net in (generator, mpd, msd))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5134ff4a-faf5-444e-8aeb-a20a23d28088",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One AdamW optimiser per side of the GAN; both discriminators share `optim_d`.\n",
    "discriminator_params = itertools.chain(msd.parameters(), mpd.parameters())\n",
    "optim_g = optim.AdamW(generator.parameters(), lr=h.learning_rate, betas=[h.adam_b1, h.adam_b2])\n",
    "optim_d = optim.AdamW(discriminator_params, lr=h.learning_rate, betas=[h.adam_b1, h.adam_b2])\n",
    "\n",
    "# Exponential learning-rate decay, stepped by the training cell.\n",
    "scheduler_g = optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay)\n",
    "scheduler_d = optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay)\n",
    "\n",
    "criterion = nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8289eeab-f8b9-4fe6-9b08-db24bc175eb0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b9c33d7c-7a50-435d-b004-f1344e9ba7c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run-level settings for training.\n",
    "max_epoch = 1000\n",
    "\n",
    "# Step intervals: the training cell tests `steps % stdout_interval` and `steps % checkpoint_interval`.\n",
    "stdout_interval = 5\n",
    "checkpoint_interval = 100\n",
    "\n",
    "# Path prefix under which the training cell formats its checkpoint file names.\n",
    "checkpoint_path = 'output'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "dd40b7cb-dd2c-4f41-bbf0-6f8fa41f0d0b",
   "metadata": {},
   "outputs": [
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 214.00 MiB. GPU 0 has a total capacty of 11.76 GiB of which 115.19 MiB is free. Process 652436 has 11.64 GiB memory in use. Of the allocated memory 10.86 GiB is allocated by PyTorch, and 498.86 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mOutOfMemoryError\u001b[0m                          Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[14], line 17\u001b[0m\n\u001b[1;32m     14\u001b[0m y_g_hat \u001b[38;5;241m=\u001b[39m generator(x)\n\u001b[1;32m     16\u001b[0m optim_d\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 17\u001b[0m y_df_hat_r, y_df_hat_g, _, _ \u001b[38;5;241m=\u001b[39m \u001b[43mmpd\u001b[49m\u001b[43m(\u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_g_hat\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdetach\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     18\u001b[0m loss_disc_f, losses_disc_f_r, losses_disc_f_g \u001b[38;5;241m=\u001b[39m discriminator_loss(y_df_hat_r, y_df_hat_g)\n\u001b[1;32m     20\u001b[0m y_ds_hat_r, y_ds_hat_g, _, _ \u001b[38;5;241m=\u001b[39m msd(y, y_g_hat\u001b[38;5;241m.\u001b[39mdetach())\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/my_own_/train_my_hifi/models.py:181\u001b[0m, in \u001b[0;36mMultiPeriodDiscriminator.forward\u001b[0;34m(self, y, y_hat)\u001b[0m\n\u001b[1;32m    179\u001b[0m fmap_gs \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m    180\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, d \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdiscriminators):\n\u001b[0;32m--> 181\u001b[0m     y_d_r, fmap_r \u001b[38;5;241m=\u001b[39m \u001b[43md\u001b[49m\u001b[43m(\u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    182\u001b[0m     y_d_g, fmap_g \u001b[38;5;241m=\u001b[39m d(y_hat)\n\u001b[1;32m    183\u001b[0m     y_d_rs\u001b[38;5;241m.\u001b[39mappend(y_d_r)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/my_own_/train_my_hifi/models.py:154\u001b[0m, in \u001b[0;36mDiscriminatorP.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    151\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mview(b, c, t \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mperiod, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mperiod)\n\u001b[1;32m    153\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m l \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconvs:\n\u001b[0;32m--> 154\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[43ml\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    155\u001b[0m     x \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mleaky_relu(x, LRELU_SLOPE)\n\u001b[1;32m    156\u001b[0m     fmap\u001b[38;5;241m.\u001b[39mappend(x)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1568\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1565\u001b[0m     bw_hook \u001b[38;5;241m=\u001b[39m hooks\u001b[38;5;241m.\u001b[39mBackwardHook(\u001b[38;5;28mself\u001b[39m, full_backward_hooks, backward_pre_hooks)\n\u001b[1;32m   1566\u001b[0m     args \u001b[38;5;241m=\u001b[39m bw_hook\u001b[38;5;241m.\u001b[39msetup_input_hook(args)\n\u001b[0;32m-> 1568\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1569\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks:\n\u001b[1;32m   1570\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m hook_id, hook \u001b[38;5;129;01min\u001b[39;00m (\n\u001b[1;32m   1571\u001b[0m         \u001b[38;5;241m*\u001b[39m_global_forward_hooks\u001b[38;5;241m.\u001b[39mitems(),\n\u001b[1;32m   1572\u001b[0m         \u001b[38;5;241m*\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks\u001b[38;5;241m.\u001b[39mitems(),\n\u001b[1;32m   1573\u001b[0m     ):\n\u001b[1;32m   1574\u001b[0m         \u001b[38;5;66;03m# mark that always called hook is run\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/conv.py:460\u001b[0m, in \u001b[0;36mConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    459\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 460\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/conv.py:456\u001b[0m, in \u001b[0;36mConv2d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m    452\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m    453\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv2d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m    454\u001b[0m                     weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m    455\u001b[0m                     _pair(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 456\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    457\u001b[0m \u001b[43m                \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 214.00 MiB. GPU 0 has a total capacty of 11.76 GiB of which 115.19 MiB is free. Process 652436 has 11.64 GiB memory in use. Of the allocated memory 10.86 GiB is allocated by PyTorch, and 498.86 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
     ]
    }
   ],
   "source": [
    "# Adversarial training: for every batch, update the discriminators first, then the generator.\n",
    "# NOTE(review): the output saved with this cell is a CUDA out-of-memory error raised inside\n",
    "# `mpd(...)`. Whole zero-padded utterances are fed to the networks, so a smaller batch size or\n",
    "# fixed-length crops may be needed - confirm against the available GPU memory.\n",
    "\n",
    "# True reproduces the previous unconditional `break`s (stop after the first batch, e.g. as a\n",
    "# quick smoke test); set to False to train for `max_epoch` epochs.\n",
    "single_batch_smoke_test = True\n",
    "\n",
    "steps = 0  # global step counter; drives the logging and checkpoint cadence\n",
    "\n",
    "for epoch in range(max_epoch):\n",
    "    for i, batch in enumerate(data_loader):\n",
    "        start_b = time.time()  # batch start time, reported as s/b in the log line\n",
    "\n",
    "        # `batch` is a list of samples; each sample has keys 'audio_label', 'sr', 'mel_spec'.\n",
    "        mel_spec = [torch.as_tensor(sample['mel_spec']).T for sample in batch]\n",
    "        audio_label = [torch.as_tensor(sample['audio_label']) for sample in batch]\n",
    "\n",
    "        # Zero-pad every sample to the longest one in the batch. The mels are transposed so the\n",
    "        # padded axis comes first for pad_sequence, then permuted back\n",
    "        # (presumably to (batch, n_mels, frames) - TODO confirm).\n",
    "        padded_mel_spectrograms = pad_sequence(mel_spec, batch_first=True, padding_value=0).float().permute(0, 2, 1)\n",
    "        padded_audio_label = pad_sequence(audio_label, batch_first=True, padding_value=0).float()\n",
    "\n",
    "        # Plain tensors are enough here; the deprecated torch.autograd.Variable wrapper is not needed.\n",
    "        x = padded_mel_spectrograms.to(device, non_blocking=True)\n",
    "        y = padded_audio_label.to(device, non_blocking=True).unsqueeze(1)\n",
    "        y_g_hat = generator(x)\n",
    "\n",
    "        # Discriminators: `detach()` keeps this backward pass out of the generator.\n",
    "        optim_d.zero_grad()\n",
    "        y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())\n",
    "        loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)\n",
    "\n",
    "        y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())\n",
    "        loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)\n",
    "\n",
    "        loss_disc_all = loss_disc_s + loss_disc_f\n",
    "        loss_disc_all.backward()\n",
    "        optim_d.step()\n",
    "\n",
    "        # Generator\n",
    "        optim_g.zero_grad()\n",
    "\n",
    "        # NOTE(review): `y_mel` and `y_g_hat_mel` are not assigned anywhere in this notebook, so\n",
    "        # the next line raises NameError as written. They are presumably the mel spectrograms of\n",
    "        # `y` and `y_g_hat`; TODO compute both before this line with the same mel transform the\n",
    "        # data pipeline uses.\n",
    "        loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45\n",
    "\n",
    "        y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)\n",
    "        y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)\n",
    "        loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)\n",
    "        loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)\n",
    "        loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)\n",
    "        loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)\n",
    "        loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel\n",
    "\n",
    "        loss_gen_all.backward()\n",
    "        optim_g.step()\n",
    "        steps += 1\n",
    "\n",
    "        if steps % stdout_interval == 0:\n",
    "            with torch.no_grad():\n",
    "                mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()\n",
    "\n",
    "            print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.\n",
    "                  format(steps, loss_gen_all.item(), mel_error, time.time() - start_b))\n",
    "\n",
    "        # checkpointing\n",
    "        if steps % checkpoint_interval == 0:\n",
    "            os.makedirs(checkpoint_path, exist_ok=True)\n",
    "            # Per-file names are kept in their own variables so `checkpoint_path` (the output\n",
    "            # directory) is never overwritten and the two files land side by side.\n",
    "            generator_ckpt = '{}/g_{:08d}'.format(checkpoint_path, steps)\n",
    "            save_checkpoint(generator_ckpt, {'generator': generator.state_dict()})\n",
    "            discriminator_ckpt = '{}/do_{:08d}'.format(checkpoint_path, steps)\n",
    "            # None of the models is wrapped for multi-GPU in this notebook, so their state dicts\n",
    "            # are read directly instead of through `.module`.\n",
    "            save_checkpoint(discriminator_ckpt,\n",
    "                            {'mpd': mpd.state_dict(),\n",
    "                             'msd': msd.state_dict(),\n",
    "                             'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,\n",
    "                             'epoch': epoch})\n",
    "\n",
    "        if single_batch_smoke_test:\n",
    "            break\n",
    "\n",
    "    scheduler_g.step()\n",
    "    scheduler_d.step()\n",
    "    if single_batch_smoke_test:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91deacad-76cf-48b0-a829-ab8cc3e27810",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64b6278b-f321-4a39-ac76-c89a7ee58177",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
